package mapp

import "gitee.com/kklt1996/data-structure/util"

// node is one key/value entry of the binary search tree that backs BstMap.
type node struct {
	// key identifies the entry and decides where the node sits in the tree.
	key interface{}

	// value is the payload stored under key.
	value interface{}

	// left is the subtree searched when a key compares lower than this key.
	left *node

	// right is the subtree searched when a key compares higher than this key.
	right *node
}

// BstMap is a key/value map whose entries are kept in a binary search tree.
type BstMap struct {
	// root is the root node of the tree.
	root *node

	// size is the number of entries currently stored.
	size int

	// comparator orders two keys. It is optional: when it is nil, keys are
	// compared with util.DefaultComparator instead. It must return a positive
	// value (conventionally 1) when thisValue > compareValue, 0 when the two
	// are equal, and a negative value (conventionally -1) when
	// thisValue < compareValue.
	comparator func(thisValue, compareValue interface{}) int
}

// Remove deletes the entry stored under key and returns the value it held.
// When key is not present nothing is removed and nil is returned.
func (bstMap *BstMap) Remove(key interface{}) interface{} {
	newRoot, removed := bstMap.removeElement(bstMap.root, key)
	// The removed node may have been the root itself, so always re-attach.
	bstMap.root = newRoot
	return removed
}

// Put stores value under key. A key that is not yet present is inserted;
// a key that already exists has its value overwritten.
func (bstMap *BstMap) Put(key interface{}, value interface{}) {
	updatedRoot := bstMap.put(bstMap.root, key, value)
	bstMap.root = updatedRoot
}

// Get returns the value stored under key, or nil when key is not present.
// A nil result is ambiguous if nil values are stored; Contains tells the two
// cases apart.
func (bstMap BstMap) Get(key interface{}) interface{} {
	found := bstMap.get(bstMap.root, key)
	return found
}

// Contains reports whether an entry for key exists in the map.
func (bstMap BstMap) Contains(key interface{}) bool {
	present := bstMap.contains(bstMap.root, key)
	return present
}

// GetSize returns the number of entries currently stored in the map.
func (bstMap BstMap) GetSize() int {
	count := bstMap.size
	return count
}

// IsEmpty reports whether the map holds no entries.
func (bstMap BstMap) IsEmpty() bool {
	empty := bstMap.size == 0
	return empty
}

// put inserts or updates the entry for key in the subtree rooted at root and
// returns the root of the resulting subtree. A key that is not present is
// added as a new node and size is incremented; a key that already exists only
// has its value replaced.
func (bstMap *BstMap) put(root *node, key interface{}, value interface{}) *node {
	// Reached an empty slot: this is where the new entry belongs.
	if root == nil {
		bstMap.size++
		return &node{key: key, value: value}
	}

	order := bstMap.compare(key, root.key)
	switch {
	case order < 0:
		root.left = bstMap.put(root.left, key, value)
	case order > 0:
		root.right = bstMap.put(root.right, key, value)
	default:
		// Equal keys: overwrite the stored value in place.
		root.value = value
	}
	return root
}

// get looks key up in the subtree rooted at root and returns the value stored
// under it, or nil when the subtree holds no entry for key.
func (bstMap BstMap) get(root *node, key interface{}) interface{} {
	// Walk down from root, going left or right as the comparison dictates,
	// until the key is found or the path runs out.
	for cur := root; cur != nil; {
		order := bstMap.compare(key, cur.key)
		switch {
		case order == 0:
			return cur.value
		case order < 0:
			cur = cur.left
		default:
			cur = cur.right
		}
	}
	return nil
}

// contains reports whether the subtree rooted at root holds an entry for key.
func (bstMap BstMap) contains(root *node, key interface{}) bool {
	// Descend one level per iteration; running off the tree means the key is
	// not present.
	for root != nil {
		order := bstMap.compare(key, root.key)
		if order == 0 {
			return true
		}
		if order < 0 {
			root = root.left
		} else {
			root = root.right
		}
	}
	return false
}

// removeMinimum unlinks the leftmost node of the subtree rooted at root and
// decrements size. It returns the root of the remaining subtree followed by
// the key and the value of the node that was removed. root must not be nil.
func (bstMap *BstMap) removeMinimum(root *node) (*node, interface{}, interface{}) {
	if root.left != nil {
		// A smaller key exists further left; remove it there and re-attach
		// whatever is left of that subtree.
		remaining, minKey, minValue := bstMap.removeMinimum(root.left)
		root.left = remaining
		return root, minKey, minValue
	}

	// root has no left child, so it is the minimum: its right subtree takes
	// its place.
	replacement := root.right
	root.right = nil
	bstMap.size--
	return replacement, root.key, root.value
}

// removeElement deletes the entry for key from the subtree rooted at root.
// It returns the root of the resulting subtree and the value that was stored
// under key; the value is nil when the subtree holds no entry for key, in
// which case nothing is changed. size is decremented once per removed entry.
func (bstMap *BstMap) removeElement(root *node, key interface{}) (*node, interface{}) {
	if root == nil {
		return nil, nil
	}

	order := bstMap.compare(key, root.key)
	if order < 0 {
		var removed interface{}
		root.left, removed = bstMap.removeElement(root.left, key)
		return root, removed
	}
	if order > 0 {
		var removed interface{}
		root.right, removed = bstMap.removeElement(root.right, key)
		return root, removed
	}

	// root is the node to delete.
	removed := root.value
	switch {
	case root.left == nil:
		// No left child: the right subtree (possibly nil) takes root's place.
		orphan := root.right
		root.right = nil
		bstMap.size--
		return orphan, removed
	case root.right == nil:
		// Only a left child: it takes root's place.
		orphan := root.left
		root.left = nil
		bstMap.size--
		return orphan, removed
	}

	// Two children: keep the node in place and overwrite it with the smallest
	// entry of its right subtree, which removeMinimum unlinks (and which also
	// takes care of decrementing size).
	root.right, root.key, root.value = bstMap.removeMinimum(root.right)
	return root, removed
}

// MaxDepth returns the height of the underlying binary search tree: the
// number of nodes on the longest path from the root down to a leaf, or 0 for
// an empty map.
func (bstMap BstMap) MaxDepth() int {
	height := bstMap.maxDepth(bstMap.root)
	return height
}

// maxDepth returns the height of the subtree rooted at root, counted in
// nodes: 0 for a nil subtree, otherwise one more than the taller of the two
// child subtrees.
func (bstMap BstMap) maxDepth(root *node) int {
	if root == nil {
		return 0
	}

	tallest := bstMap.maxDepth(root.left)
	if rightDepth := bstMap.maxDepth(root.right); rightDepth >= tallest {
		tallest = rightDepth
	}
	// Account for root itself.
	return tallest + 1
}

// compare orders thisValue relative to compareValue. It uses the map's own
// comparator when one is configured and util.DefaultComparator otherwise.
// Callers in this file read the result as negative (less), zero (equal) or
// positive (greater).
func (bstMap *BstMap) compare(thisValue interface{}, compareValue interface{}) int {
	if bstMap.comparator == nil {
		return util.DefaultComparator(thisValue, compareValue)
	}
	return bstMap.comparator(thisValue, compareValue)
}
